{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [
    {
     "file_id": "1Ywx5qv0bNB1HNfcTdRBFmZmxza7qFLvb",
     "timestamp": 1717383158529
    }
   ],
   "toc_visible": true,
   "gpuType": "T4",
   "authorship_tag": "ABX9TyNIHYjpenSao+xcehmyCSlQ"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "##### Copyright 2024 Google LLC."
   ],
   "metadata": {
    "id": "nH85BOCo7YYk"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "9tQNAByc7U9g"
   },
   "outputs": [],
   "source": [
    "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Play with AI - Guess the word\n",
    "\n",
    "This cookbook shows how to use the instruction-tuned version of Gemma as a chatbot to play a \"Guess the word\" game.\n",
    "\n",
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Guess_the_word.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "</table>"
   ],
   "metadata": {
    "id": "F7r2q0wS7bxf"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Setup\n",
    "\n",
    "### Select the Colab runtime\n",
    "To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the Gemma model. A T4 GPU is enough for this notebook:\n",
    "\n",
    "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
    "2. Select **Change runtime type**.\n",
    "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
    "\n",
    "\n",
    "### Gemma setup on Kaggle\n",
    "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
    "\n",
    "* Get access to Gemma on kaggle.com.\n",
    "* Select a Colab runtime with sufficient resources to run the Gemma 2B model.\n",
    "* Generate and configure a Kaggle username and API key.\n",
    "\n",
    "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
   ],
   "metadata": {
    "id": "ZHrL4tqs7mYK"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Set environment variables\n",
    "\n",
    "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
   ],
   "metadata": {
    "id": "pQEE8RoO75F-"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "import os\n",
    "from google.colab import userdata\n",
    "\n",
    "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
    "# vars as appropriate for your system.\n",
    "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
    "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
   ],
   "metadata": {
    "id": "XsY2Ut7a76Wa"
   },
   "execution_count": null,
   "outputs": []
  },
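  {
   "cell_type": "markdown",
   "source": [
    "Outside Colab, `userdata.get` is not available. One option is to read the credentials from a Kaggle API token file. The sketch below assumes a JSON file with `username` and `key` fields (the layout of the `kaggle.json` token that Kaggle lets you download). The helper name `load_kaggle_credentials` and the default path are assumptions made for illustration, not part of this cookbook."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def load_kaggle_credentials(path=\"~/.kaggle/kaggle.json\"):\n",
    "    # Hypothetical helper: read a Kaggle API token file and export its\n",
    "    # fields as the environment variables this notebook expects.\n",
    "    token = json.loads(Path(path).expanduser().read_text())\n",
    "    os.environ[\"KAGGLE_USERNAME\"] = token[\"username\"]\n",
    "    os.environ[\"KAGGLE_KEY\"] = token[\"key\"]\n",
    "\n",
    "\n",
    "# Example (uncomment when running outside Colab):\n",
    "# load_kaggle_credentials()"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },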
  {
   "cell_type": "markdown",
   "source": [
    "### Install dependencies\n",
    "\n",
    "Install Keras and KerasNLP."
   ],
   "metadata": {
    "id": "Ea_56Zpa78Gu"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
    "!pip install -q -U keras-nlp\n",
    "!pip install -q -U keras"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AxPjbcnC79ck",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1717382922185,
     "user_tz": -540,
     "elapsed": 77569,
     "user": {
      "displayName": "Ju-yeong Ji",
      "userId": "07739812480064137731"
     }
    },
    "outputId": "7ff69f49-bdb2-4d13-95d5-bb635a6d73d9"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Create a chat helper to manage the conversation state"
   ],
   "metadata": {
    "id": "a_QCPQLf8OU0"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "import re\n",
    "\n",
    "import keras\n",
    "import keras_nlp\n",
    "\n",
    "# Run at half precision to fit in memory\n",
    "keras.config.set_floatx(\"bfloat16\")\n",
    "\n",
    "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_1.1_instruct_2b_en\")\n",
    "gemma_lm.compile(sampler=\"top_k\")\n",
    "\n",
    "\n",
    "class GemmaChat:\n",
    "\n",
    "    __START_TURN__ = \"<start_of_turn>\"\n",
    "    __END_TURN__ = \"<end_of_turn>\"\n",
    "\n",
    "    def __init__(self, model, system=\"\", history=None):\n",
    "        self.model = model\n",
    "        self.system = system\n",
    "        if not history:\n",
    "            self.history = []\n",
    "        else:\n",
    "            self.history = history\n",
    "\n",
    "    def format_message(self, message, prefix=\"\"):\n",
    "        return f\"{self.__START_TURN__}{prefix}\\n{message}{self.__END_TURN__}\\n\"\n",
    "\n",
    "    def add_to_history(self, message, prefix=\"\"):\n",
    "        formatted_message = self.format_message(message, prefix)\n",
    "        self.history.append(formatted_message)\n",
    "\n",
    "    def get_full_prompt(self):\n",
    "        prompt = \"\\n\".join([self.system + \"\\n\", *self.history])\n",
    "        return prompt\n",
    "\n",
    "    def send_message(self, message):\n",
    "        self.add_to_history(message, \"user\")\n",
    "        prompt = self.get_full_prompt()\n",
    "        response = self.model.generate(prompt, max_length=4096)\n",
    "        # generate() returns the prompt followed by the completion, so strip\n",
    "        # the prompt to keep only the model's reply.\n",
    "        result = response.replace(prompt, \"\")\n",
    "        self.add_to_history(result, \"model\")\n",
    "        return result\n",
    "\n",
    "    def show_history(self):\n",
    "        for h in self.history:\n",
    "            print(h)\n",
    "\n",
    "\n",
    "chat = GemmaChat(gemma_lm)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2BmB5Zua8Vs0",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1717383020460,
     "user_tz": -540,
     "elapsed": 36125,
     "user": {
      "displayName": "Ju-yeong Ji",
      "userId": "07739812480064137731"
     }
    },
    "outputId": "cb4d7013-0353-40d0-83c8-93dab9bdc06d"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
      "Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
      "Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
      "Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
      "Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
      "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
      "Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n"
     ]
    }
   ]
  },
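  {
   "cell_type": "markdown",
   "source": [
    "To make the turn formatting concrete, here is a minimal sketch that rebuilds the prompt string the same way `GemmaChat.format_message` and `GemmaChat.get_full_prompt` do, without loading the model. The names `format_turn` and `build_prompt` are hypothetical and exist only for this illustration."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "START_TURN = \"<start_of_turn>\"\n",
    "END_TURN = \"<end_of_turn>\"\n",
    "\n",
    "\n",
    "def format_turn(message, role):\n",
    "    # Mirrors GemmaChat.format_message: one turn wrapped in control tokens.\n",
    "    return f\"{START_TURN}{role}\\n{message}{END_TURN}\\n\"\n",
    "\n",
    "\n",
    "def build_prompt(system, turns):\n",
    "    # Mirrors GemmaChat.get_full_prompt: the system text, then each turn,\n",
    "    # joined with newlines.\n",
    "    return \"\\n\".join([system + \"\\n\", *turns])\n",
    "\n",
    "\n",
    "# repr() makes the newlines and control tokens visible.\n",
    "print(repr(build_prompt(\"\", [format_turn(\"Hello\", \"user\")])))"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },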
  {
   "cell_type": "markdown",
   "source": [
    "## Play the game"
   ],
   "metadata": {
    "id": "_1jyCoRd8EwX"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "theme = input(\"Choose your theme: \")\n",
    "setup_message = f\"Generate a random single word from {theme}.\"\n",
    "\n",
    "chat.history.clear()\n",
    "answer = chat.send_message(setup_message).split()[0]\n",
    "answer = re.sub(r\"\\W+\", \"\", answer)  # keep only letters, digits, and '_'\n",
    "chat.history.clear()\n",
    "cmd_exit = \"quit\"\n",
    "question = f'Describe the word \"{answer}\" without saying it.'\n",
    "\n",
    "resp = \"\"\n",
    "while resp.lower() != answer.lower() and resp != cmd_exit:\n",
    "    text = chat.send_message(question)\n",
    "    if resp == \"\":\n",
    "        print(f'Guess what I\\'m thinking.\\nType \"{cmd_exit}\" if you want to quit.')\n",
    "    remove_answer = re.compile(re.escape(answer), re.IGNORECASE)\n",
    "    text = remove_answer.sub(\"XXXX\", text)\n",
    "    print(text)\n",
    "    resp = input(\"\\n> \")\n",
    "\n",
    "if resp == cmd_exit:\n",
    "    print(f\"The answer was {answer}.\\n\")\n",
    "else:\n",
    "    print(\"Correct!\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "zoWDt87V83rZ",
    "executionInfo": {
     "status": "ok",
     "timestamp": 1717383101577,
     "user_tz": -540,
     "elapsed": 78281,
     "user": {
      "displayName": "Ju-yeong Ji",
      "userId": "07739812480064137731"
     }
    },
    "outputId": "6240bd4e-ceab-4866-d9cb-ac1e56ac7aee"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Choose your theme: animal\n",
      "Guess what I'm thinking.\n",
      "Type \"quit\" if you want to quit.\n",
      "The word evokes powerful images in the minds of many, a symbol of strength, courage, and dominance.\n",
      "\n",
      "> tiger\n",
      "**Answer:** A creature with a powerful and majestic presence, known for its courage, strength, and fierce nature.\n",
      "\n",
      "> lion\n",
      "Correct!\n"
     ]
    }
   ]
  }
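  ,
  {
   "cell_type": "markdown",
   "source": [
    "The answer handling in the game loop can be tried on its own. The sketch below isolates the two regular-expression steps: keeping only word characters from the first token of the model's reply, and masking the answer (case-insensitively) in a hint. The helper names `extract_answer` and `mask_answer`, and the sample strings, are hypothetical."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def extract_answer(reply):\n",
    "    # Take the first whitespace-separated token and drop every character\n",
    "    # that is not a letter, digit, or underscore.\n",
    "    tokens = reply.split()\n",
    "    return re.sub(r\"\\W+\", \"\", tokens[0]) if tokens else \"\"\n",
    "\n",
    "\n",
    "def mask_answer(text, answer, mask=\"XXXX\"):\n",
    "    # Replace every case-insensitive occurrence of the answer with the mask.\n",
    "    return re.compile(re.escape(answer), re.IGNORECASE).sub(mask, text)\n",
    "\n",
    "\n",
    "answer = extract_answer(\"**Lion.** It is a large cat.\")\n",
    "print(answer)\n",
    "print(mask_answer(\"The lion is known as the king; LIONS live in prides.\", answer))"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }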
 ]
}